""" learning_module.py
    Python module for training and testing models
    Developed as part of Recur project
    November 2020
"""
import sys
import typing
from dataclasses import dataclass

import numpy as np
import torch

from utils import now

# Pylint checks disabled for this module:
#     Too many branches (R0912), Too many statements (R0915), No member (E1101),
#     Not callable (E1102), Invalid name (C0103), Bare except (W0702),
#     Too many local variables (R0914), Missing docstring (C0116, C0115).
# pylint: disable=R0912, R0915, E1101, E1102, C0103, W0702, R0914, C0116, C0115


@dataclass
class OptimizerWithSched:
    """Attributes for optimizer, lr schedule, and lr warmup.

    The train functions in this module call ``optimizer.zero_grad()`` and
    ``optimizer.step()`` every batch, then ``scheduler.step()`` and
    ``warmup.dampen()`` once at the end of each epoch.
    """
    # Real ``typing.Any`` annotations (not the string "typing.Any") so that
    # ``typing.get_type_hints`` can resolve them.
    optimizer: typing.Any  # must provide zero_grad() and step()
    scheduler: typing.Any  # must provide step(); called once per epoch
    warmup: typing.Any     # must provide dampen(); called once per epoch


@dataclass()
class TrainingSetup:
    """Configuration object handed to the train functions as ``train_setup``.

    Fields, in positional order:
        model:   model identifier (not read by the train functions shown here)
        tol:     tolerance used by train_with_modes for non-classification
                 problems: an output counts as correct when |output - target| < tol
        problem: "classification" selects cross-entropy in train_with_modes;
                 any other value selects mean-squared error
        mode:    training mode; train_with_modes handles 0 and 1
    """

    model: str
    tol: int
    problem: str
    mode: int


@dataclass()
class TestingSetup:
    """Configuration object handed to the test functions as ``test_setup``.

    Fields, in positional order:
        model:   model identifier (not read by the test functions shown here)
        tol:     tolerance value (read but unused by test_segment_with_modes)
        problem: problem type; test_segment_with_modes requires "segment"
                 (case-insensitive)
        mode:    evaluation mode; mode 1 enables the per-iteration statistics
                 in test_with_modes and test_segment_with_modes
    """

    model: str
    tol: int
    problem: str
    mode: int


def test(net, testloader, test_setup, device):
    """Function to evaluate the performance of the model
    input:
        net:        Pytorch network object
        testloader: Pytorch dataloader object
        device:     Device on which data is to be loaded (cpu or gpu)
    return
        Testing accuracy
    """
    net.eval()
    net.to(device)
    correct = 0
    total = 0

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):

            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)

            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()

            total += targets.size(0)

    accuracy = 100.0 * correct / total
    return accuracy


def train(net, trainloader, optimizer_obj, train_setup, device):
    """Function to perform one epoch of classification training
    input:
        net:            Pytorch network object
        trainloader:    iterable of (inputs, targets) batches
        optimizer_obj:  holder of ``optimizer``, ``scheduler`` and ``warmup``;
                        ``scheduler.step()`` and ``warmup.dampen()`` are called
                        once, after the epoch
        train_setup:    unused here; accepted so all train functions in this
                        module share one signature
        device:         Device on which data is to be loaded (cpu or gpu)
    output:
        train_loss:     Float, sample-weighted average cross-entropy loss
        acc:            Float, percentage of training data correctly labeled
    """
    net.train()
    net = net.to(device)
    optimizer = optimizer_obj.optimizer
    lr_scheduler = optimizer_obj.scheduler
    warmup_scheduler = optimizer_obj.warmup
    criterion = torch.nn.CrossEntropyLoss()

    train_loss = 0
    correct = 0
    total = 0

    for inputs, targets in trainloader:
        inputs, targets = inputs.to(device), targets.to(device)
        batch_size = targets.size(0)
        optimizer.zero_grad()
        outputs = net(inputs)
        # Drop singleton non-batch dims only. A bare ``squeeze()`` would also
        # drop the batch dim for a one-sample batch, breaking both the loss
        # and ``argmax(1)`` below.
        outputs = outputs.reshape(outputs.size(0),
                                  *[dim for dim in outputs.shape[1:] if dim != 1])
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # ``loss`` is a batch mean; re-weight by batch size for an epoch mean.
        train_loss += loss.item() * batch_size
        predicted = outputs.argmax(1)
        correct += predicted.eq(targets).sum().item()
        total += batch_size

    train_loss = train_loss / total
    acc = 100.0 * correct / total

    lr_scheduler.step()
    warmup_scheduler.dampen()

    return train_loss, acc


def test_with_modes(net, testloader, test_setup, device):
    """Function to evaluate the performance of the model
    input:
        net:        Pytorch network object
        testloader: Pytorch dataloader object
        device:     Device on which data is to be loaded (cpu or gpu)
    return
        Testing accuracy
    """
    net.eval()
    net.to(device)
    correct = 0
    total = 0
    thinker_at_each_iter = torch.zeros(net.iters).to(device)
    thinker_on_agreement = torch.zeros(net.iters).to(device)
    correct_for_the_first_time = torch.zeros(net.iters + 1).to(device)
    mode = test_setup.mode

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):

            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)

            predicted = outputs.argmax(1)
            correct += predicted.eq(targets).sum().item()

            if mode == 1:
                thoughts = net.thoughts
                predicted = thoughts.argmax(2)
                pred_correctly = predicted.eq(targets)

                thinker_at_each_iter += predicted.eq(targets).sum(1)
                pred_on_agreement = (-1 * torch.ones(inputs.shape[0])).to(device)
                iter_of_agreement = torch.zeros(inputs.shape[0], dtype=torch.int).to(device)
                threshold = net.iters // 2 + 1 if net.iters % 2 else net.iters // 2

                for k in range(inputs.shape[0]):
                    label_frequency = torch.zeros(net.linear.out_features)
                    j = 0
                    while pred_on_agreement[k] == -1:
                        current_pred = predicted[j, k]
                        label_frequency[current_pred] += 1
                        if label_frequency[current_pred] >= threshold or j == net.iters - 1:
                            pred_on_agreement[k] = current_pred
                            iter_of_agreement[k] = j
                        j += 1
                correct_on_agreement = pred_on_agreement.eq(targets)

                for i, wte in enumerate(iter_of_agreement.cpu().tolist()):
                    thinker_on_agreement[wte] += correct_on_agreement[i]

                pred_correctly = torch.cat([pred_correctly,
                                            torch.ones_like(pred_correctly[0]).unsqueeze(0)])
                first_correct_pred = [torch.min(torch.where(pred_correctly[:, i] > 0)[0]).item() for i
                                      in range(pred_correctly.size(1))]

                iteration, count = torch.tensor(first_correct_pred).unique(return_counts=True)
                for i, ite in enumerate(iteration):
                    correct_for_the_first_time[ite] += count[i].item()

                total += targets.size(0)

        if mode == 1:
            correct_for_the_first_time = correct_for_the_first_time[:-1]
            upper_bound = (100. * torch.cumsum(correct_for_the_first_time, 0) / total).cpu().tolist()
            correct_for_the_first_time = correct_for_the_first_time.cpu().tolist()
            raw_accuracy = (100. * thinker_at_each_iter / total).cpu().tolist()
            accuracy_on_agreement = (100. * thinker_on_agreement.cumsum(0) / total).cpu().tolist()
            raw_accuracy = raw_accuracy[:net.iters]

            return np.array([raw_accuracy, correct_for_the_first_time,
                             upper_bound, accuracy_on_agreement])

        total += targets.size(0)
        accuracy = 100.0 * correct / total

    return accuracy


def train_with_modes(net, trainloader, optimizer_obj, train_setup, device):
    """Function to perform one epoch of training, optionally supervising iterations
    input:
        net:            Pytorch network object; must expose ``iters`` and, in
                        mode 1, ``thoughts`` after a forward pass (indexed below
                        as [iteration, sample, class])
        trainloader:    iterable of (inputs, targets) batches
        optimizer_obj:  holder of ``optimizer``, ``scheduler`` and ``warmup``;
                        ``scheduler.step()`` and ``warmup.dampen()`` are called
                        once, after the epoch
        train_setup:    object with ``tol``, ``problem`` and ``mode``.
                        problem == "classification" selects cross-entropy;
                        anything else selects MSE, where an output counts as
                        correct when |output - target| < tol
        device:         Device on which data is to be loaded (cpu or gpu)
    output:
        mode 0: (train_loss, acc) -- sample-weighted average loss and the
                percentage of samples whose final output is correct
        mode 1: (np.array([train_loss, thinker_loss]), per-iteration accuracy in
                percent as a numpy array). The loss is the sum of the losses of
                the last min(iters, 3) iterations, so both entries are equal.
    Any other mode prints a message and exits the process with status 1.
    """
    net.train()
    net = net.to(device)
    optimizer = optimizer_obj.optimizer
    lr_scheduler = optimizer_obj.scheduler
    warmup_scheduler = optimizer_obj.warmup
    tol, problem, mode = train_setup.tol, train_setup.problem, train_setup.mode

    if problem == "classification":
        criterion = torch.nn.CrossEntropyLoss()
    else:
        criterion = torch.nn.MSELoss()

    train_loss = 0
    total_thinker_loss = 0
    correct = 0
    total = 0
    num_correct = torch.zeros(net.iters)

    for inputs, targets in trainloader:
        inputs, targets = inputs.to(device), targets.to(device)
        batch_size = targets.size(0)
        optimizer.zero_grad()
        outputs = net(inputs)
        # Drop singleton non-batch dims (e.g. a trailing regression dim) but
        # never the batch dim, so a one-sample batch keeps shape (1, ...).
        outputs = outputs.reshape(outputs.size(0),
                                  *[dim for dim in outputs.shape[1:] if dim != 1])

        if problem == "classification":
            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()
        else:
            correct += (torch.abs(outputs - targets) < tol).sum().item()

        if mode == 0:
            loss = criterion(outputs, targets)
        elif mode == 1:
            thoughts = net.thoughts
            for i in range(net.iters):
                num_correct[i] += thoughts[i].argmax(1).eq(targets).sum().item()

            # Supervise only the last (up to) three iterations.
            num_supervised = min(net.iters, 3)
            loss = torch.stack([criterion(thoughts[-(i + 1)], targets)
                                for i in range(num_supervised)]).sum()
            # Weight each batch mean by its size, like ``train_loss`` below, so
            # a short final batch is not over-counted.
            total_thinker_loss += loss.item() * batch_size
        else:
            print(now(), f"mode {mode} not yet implemented, exiting.")
            sys.exit(1)

        loss.backward()
        optimizer.step()

        train_loss += loss.item() * batch_size
        total += batch_size

    train_loss = train_loss / total
    total_thinker_loss = total_thinker_loss / total
    lr_scheduler.step()
    warmup_scheduler.dampen()

    if mode == 0:
        acc = 100.0 * correct / total
        return train_loss, acc
    acc = 100.0 * num_correct / total
    return np.array([train_loss, total_thinker_loss]), acc.detach().cpu().numpy()


# def test_segment(net, testloader, test_setup, device):
#     """Function to evaluate the performance of the model
#     input:
#         net:        Pytorch network object
#         testloader: Pytorch dataloader object
#         device:     Device on which data is to be loaded (cpu or gpu)
#     return
#         Testing accuracy
#     """
#     net.eval()
#     net.to(device)
#     correct = 0
#     total = 0
#
#     with torch.no_grad():
#         for batch_idx, (inputs, targets) in enumerate(testloader):
#
#             inputs, targets = inputs.to(device), targets.to(device)[:, 0, :, :].long()
#             outputs = net(inputs)
#
#             predicted = outputs.argmax(1) * inputs.max(1)[0]
#             correct += torch.amin(predicted == targets, dim=[1, 2]).sum().item()
#             total += targets.size(0)
#
#     accuracy = 100.0 * correct / total
#     return accuracy


def test_segment_with_modes(net, testloader, test_setup, device):
    """Function to evaluate the performance of the model
    input:
        net:        Pytorch network object
        testloader: Pytorch dataloader object
        device:     Device on which data is to be loaded (cpu or gpu)
    return
        Testing accuracy
    """
    net.eval()
    net.to(device)
    tol, problem, mode = test_setup.tol, test_setup.problem, test_setup.mode
    correct = 0
    total = 0
    thinker_at_each_iter = torch.zeros(net.iters).to(device)
    thinker_on_agreement = torch.zeros(net.iters).to(device)
    correct_for_the_first_time = torch.zeros(net.iters + 1).to(device)
    file = open("./agreement_deeper_small_mazes.txt", "w+")
    file.write("index,exit_iter,correct \n")
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):

            inputs, targets = inputs.to(device), targets.to(device)[:, 0, :, :].long()
            outputs = net(inputs)

            if problem.lower() != "segment":
                raise ValueError("Problem type wrongly defined")

            predicted = outputs.argmax(1) * inputs.max(1)[0]
            correct += torch.amin(predicted == targets, dim=[1, 2]).sum().item()
            total += targets.size(0)

            if mode == 1:
                thoughts = net.thoughts
                predicted = thoughts.argmax(2) * inputs.max(1)[0]
                pred_correctly = torch.amin(predicted == targets, dim=[2, 3])

                thinker_at_each_iter += torch.amin(predicted == targets, dim=[2, 3]).sum(1)
                pred_on_agreement = torch.zeros(inputs.shape[0], inputs.shape[2], inputs.shape[3]).to(device)
                iter_of_agreement = torch.zeros(inputs.shape[0], dtype=torch.int).to(device)
                # threshold = min(3, (net.iters - 0.5) // 2 + 1)
                threshold = 2

                for k in range(inputs.shape[0]):
                    label_frequency = torch.ones(net.iters)
                    j = 1
                    while j < net.iters:
                        current_pred = predicted[j, k]
                        for l in range(j):

                            if torch.eq(predicted[l,k], predicted[j,k]).sum() > \
                                    0.9999*inputs.shape[2]*inputs.shape[3]:
                                label_frequency[l] += 1
                                break
                        if label_frequency[l] >= threshold or j == net.iters - 1:
                            pred_on_agreement[k] = current_pred
                            iter_of_agreement[k] = j
                            break
                        j += 1
                correct_on_agreement = torch.amin(pred_on_agreement == targets, dim=[1, 2])

                for i, wte in enumerate(iter_of_agreement.cpu().tolist()):
                    thinker_on_agreement[wte] += correct_on_agreement[i]
                    full_text = f"{batch_idx}, {iter_of_agreement[i]}, {correct_on_agreement[i]} \n"
                    file.write(full_text)

                pred_correctly = torch.cat([pred_correctly,
                                            torch.ones_like(pred_correctly[0]).unsqueeze(0)])
                first_correct_pred = [torch.min(torch.where(pred_correctly[:, i] > 0)[0]).item() for i
                                      in range(pred_correctly.size(1))]

                iteration, count = torch.tensor(first_correct_pred).unique(return_counts=True)
                for i, ite in enumerate(iteration):
                    correct_for_the_first_time[ite] += count[i].item()

        if mode == 1:
            correct_for_the_first_time = correct_for_the_first_time[:-1]
            upper_bound = (100. * torch.cumsum(correct_for_the_first_time, 0) / total).cpu().tolist()
            correct_for_the_first_time = correct_for_the_first_time.cpu().tolist()
            raw_accuracy = (100. * thinker_at_each_iter / total).cpu().tolist()
            accuracy_on_agreement = (100. * thinker_on_agreement.cumsum(0) / total).cpu().tolist()
            raw_accuracy = raw_accuracy[:net.iters]

            return np.array([raw_accuracy, correct_for_the_first_time,
                             upper_bound, accuracy_on_agreement])

        total += targets.size(0)
        accuracy = 100.0 * correct / total

    return accuracy


def train_segment(net, trainloader, optimizer_obj, train_setup, device):
    """Run one epoch of pixel-wise segmentation training.

    input:
        net:            Pytorch network object producing per-pixel class scores
                        of shape (N, C, H, W)
        trainloader:    iterable of (inputs, targets) batches; channel 0 of
                        targets is used as the long per-pixel labels
        optimizer_obj:  holder of ``optimizer``, ``scheduler`` and ``warmup``;
                        ``scheduler.step()`` and ``warmup.dampen()`` run once,
                        after the epoch
        train_setup:    unused here
        device:         Device on which data is to be loaded (cpu or gpu)
    output:
        train_loss:     Float. Each batch's mean loss over pixels whose
                        channel-mean input is positive, weighted by the batch's
                        labelled-pixel count, divided by the total labelled-pixel
                        count.
        acc:            Float, percentage of samples whose prediction (scaled by
                        the per-pixel channel max of the input) matches the
                        target on every pixel
    """
    net.train()
    net = net.to(device)
    opt = optimizer_obj.optimizer
    scheduler = optimizer_obj.scheduler
    warmup = optimizer_obj.warmup
    pixel_criterion = torch.nn.CrossEntropyLoss(reduction="none")

    loss_accum, n_right, n_seen, n_pixels = 0, 0, 0, 0

    for batch_in, batch_tgt in trainloader:
        batch_in = batch_in.to(device)
        batch_tgt = batch_tgt.to(device)[:, 0, :, :].long()
        opt.zero_grad()
        logits = net(batch_in)

        n, c, h, w = logits.size()
        tgt_grid = batch_tgt.view(n, h, w, 1)

        # Channels-last logits, restricted to labelled (target >= 0) pixels.
        flat_logits = logits.permute(0, 2, 3, 1).contiguous()
        flat_logits = flat_logits[tgt_grid.repeat(1, 1, 1, c) >= 0].view(-1, c)

        # A labelled pixel is "on the path" when its channel-mean input is > 0.
        in_means = batch_in.permute(0, 2, 3, 1).contiguous().mean(3).unsqueeze(-1)
        in_means = in_means[tgt_grid >= 0].view(-1, 1)
        on_path = (in_means > 0).squeeze()

        labelled_tgt = batch_tgt[batch_tgt >= 0.0]
        batch_loss = pixel_criterion(flat_logits, labelled_tgt)[on_path].mean()
        batch_loss.backward()
        opt.step()

        loss_accum += batch_loss.item() * on_path.size(0)
        n_pixels += on_path.size(0)

        masked_pred = logits.argmax(1) * batch_in.max(1)[0]
        n_right += torch.amin(masked_pred == batch_tgt, dim=[1, 2]).sum().item()
        n_seen += batch_tgt.size(0)

    avg_loss = loss_accum / n_pixels
    accuracy = 100.0 * n_right / n_seen
    scheduler.step()
    warmup.dampen()

    return avg_loss, accuracy


def _segment_path_mask(inputs, targets):
    """Flag each labelled pixel (target >= 0) whose channel-mean input is > 0."""
    labelled = (targets >= 0).unsqueeze(-1)                          # (N, H, W, 1)
    pixel_means = inputs.permute(0, 2, 3, 1).mean(3, keepdim=True)  # (N, H, W, 1)
    return (pixel_means[labelled].view(-1, 1) > 0).squeeze()


def _segment_pixel_loss(criterion, logits, targets, path_mask):
    """Mean per-pixel ``criterion`` of (N, C, H, W) ``logits`` over masked labelled pixels."""
    labelled = targets >= 0
    per_pixel_logits = logits.permute(0, 2, 3, 1)[labelled]  # (num_labelled, C)
    per_pixel_loss = criterion(per_pixel_logits, targets[labelled])
    return per_pixel_loss[path_mask].mean()


def train_segment_with_modes(net, trainloader, optimizer_obj, train_setup, device):
    """Function to perform one epoch of segmentation training
    input:
        net:            Pytorch network object; must expose ``iters`` and, in
                        mode 1, ``thoughts`` after a forward pass (indexed below
                        as [iteration, sample, class, height, width])
        trainloader:    iterable of (inputs, targets) batches; channel 0 of
                        targets is used as the long per-pixel labels
        optimizer_obj:  holder of ``optimizer``, ``scheduler`` and ``warmup``;
                        ``scheduler.step()`` and ``warmup.dampen()`` run once,
                        after the epoch
        train_setup:    object whose ``mode`` (0 or 1) selects the loss
        device:         Device on which data is to be loaded (cpu or gpu)
    output:
        mode 0: (train_loss, acc) -- pixel-weighted loss over pixels whose
                channel-mean input is positive, and the percentage of samples
                solved on every pixel
        mode 1: (np.array([train_loss, train_loss]), per-iteration accuracy in
                percent). The loss sums the last min(3, iters) iterations;
                train_loss is averaged per supervised iteration.
    raises
        ValueError if ``train_setup.mode`` is not 0 or 1.
    """
    net.train()
    net = net.to(device)
    optimizer = optimizer_obj.optimizer
    lr_scheduler = optimizer_obj.scheduler
    warmup_scheduler = optimizer_obj.warmup
    mode = train_setup.mode
    # Fail fast: an unknown mode would otherwise skip every batch and then
    # divide by zero pixels.
    if mode not in (0, 1):
        raise ValueError(f"mode {mode} not yet implemented")
    criterion = torch.nn.CrossEntropyLoss(reduction="none")

    train_loss = 0
    correct = 0
    total = 0
    total_pixels = 0
    num_correct = torch.zeros(net.iters)

    for inputs, targets in trainloader:
        inputs, targets = inputs.to(device), targets.to(device)[:, 0, :, :].long()
        optimizer.zero_grad()
        outputs = net(inputs)
        path_mask = _segment_path_mask(inputs, targets)
        # Predictions are scaled by the per-pixel channel max of the input.
        input_mask = inputs.max(1)[0]

        if mode == 0:
            loss = _segment_pixel_loss(criterion, outputs, targets, path_mask)
            total_pixels += path_mask.size(0)
            predicted = outputs.argmax(1) * input_mask
            correct += torch.amin(predicted == targets, dim=[1, 2]).sum().item()
        else:
            thoughts = net.thoughts
            for i in range(net.iters):
                predicted = thoughts[i].argmax(1) * input_mask
                num_correct[i] += torch.amin(predicted == targets, dim=[1, 2]).sum().item()

            # Supervise only the last (up to) three iterations; counting the
            # pixels once per supervised iteration averages train_loss over them.
            num_supervised = min(3, net.iters)
            loss = sum(_segment_pixel_loss(criterion, thoughts[-(i + 1)], targets, path_mask)
                       for i in range(num_supervised))
            total_pixels += num_supervised * path_mask.size(0)

        loss.backward()
        optimizer.step()
        train_loss += loss.item() * path_mask.size(0)
        total += targets.size(0)

    train_loss = train_loss / total_pixels
    lr_scheduler.step()
    warmup_scheduler.dampen()

    if mode == 0:
        acc = 100.0 * correct / total
        return train_loss, acc
    acc = 100.0 * num_correct / total
    return np.array([train_loss, train_loss]), acc.detach().cpu().numpy()

def test_segment(net, testloader, test_setup, device):
    """Function to evaluate the performance of the model
    input:
        net:        Pytorch network object
        testloader: Pytorch dataloader object
        device:     Device on which data is to be loaded (cpu or gpu)
    return
        Testing accuracy
    """
    net.eval()
    net.to(device)
    correct = 0
    confidence = torch.zeros(net.iters)
    total = 0
    total_pixels = 0
    file = open("./max_confidence_index_deeper_small_mazes.txt", "w+")
    file.write("index,exit_iter,correct \n")
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):

            inputs, targets = inputs.to(device), targets.to(device)[:, 0, :, :].long()
            outputs = net(inputs)
            confidence_array = torch.zeros(net.iters, inputs.size(0))
            for i, thought in enumerate(net.thoughts):
                conf = torch.nn.functional.softmax(thought.detach(), dim=1).max(1)[0] \
                       * inputs.max(1)[0]
                confidence[i] += conf.sum().item()
                confidence_array[i] = conf.sum([1, 2]) / inputs.max(1)[0].sum([1, 2])

            exit_iter = confidence_array.argmax(0)

            # increasing_conf = (confidence_array[2:] - confidence_array[1:-1]).sign()
            # increasing_conf[-1, :] = -1
            # exit_iter = increasing_conf.argmin(0)+2

            best_thoughts = net.thoughts[exit_iter, torch.arange(net.thoughts.size(1))].squeeze()
            if best_thoughts.shape[0] != inputs.shape[0]:
                best_thoughts = best_thoughts.unsqueeze(0)
            predicted = best_thoughts.argmax(1) * inputs.max(1)[0]
            correct += torch.amin(predicted == targets, dim=[1, 2]).sum().item()
            full_text = f"{batch_idx}, {exit_iter.item()}, {torch.amin(predicted == targets, dim=[1, 2]).sum().item()} \n"
            file.write(full_text)

            total_pixels += inputs.max(1)[0].sum().item()
            total += targets.size(0)

    accuracy = 100.0 * correct / total
    average_confs = confidence / total_pixels
    print(f"Average per pixel confidence: {average_confs}")
    return accuracy


# def test_segment(net, testloader, test_setup, device):
#     """Function to evaluate the performance of the model
#     input:
#         net:        Pytorch network object
#         testloader: Pytorch dataloader object
#         device:     Device on which data is to be loaded (cpu or gpu)
#     return
#         Testing accuracy
#     """
#     net.eval()
#     net.to(device)
#     correct = torch.zeros(net.iters)
#     total = 0
#
#     with torch.no_grad():
#         for batch_idx, (inputs, targets) in enumerate(testloader):
#
#             inputs, targets = inputs.to(device), targets.to(device)[:, 0, :, :].long()
#             outputs = net(inputs)
#             for i, thought in enumerate(net.thoughts):
#                 predicted = thought.argmax(1) * inputs.max(1)[0]
#                 correct[i] += torch.amin(predicted==targets, dim=[1, 2]).sum().item()
#             total += targets.size(0)
#
#     accuracy = 100.0 * correct / total
#     return list(accuracy.numpy())
